import logging
import warnings
import numpy as np

import efel
from efel.pyfeatures import get_cpp_feature
from efel.pyfeatures.pyfeatures import _get_cpp_data
from sklearn.cluster import KMeans

# Module-level logger named after this module rather than the root logger, so
# messages from this file can be filtered/configured independently.
logger = logging.getLogger(__name__)
# NOTE: configures the root logging handler at import time (kept as-is for
# backward compatibility).
logging.basicConfig(level=logging.INFO)


def _get_burst_thresh(isis, max_isis=50.0, n_init=100, random_state=None):
    """Find a split of isis for inter and intra bursts.

    The ISIs are clustered into two groups with 1D k-means. The threshold is the
    midpoint between the two cluster centres, capped at ``max_isis``.

    Args:
        isis: 1D array-like of inter-spike intervals.
        max_isis: upper bound of the returned threshold.
        n_init: number of k-means restarts; the high default increases the
            stability of the algorithm.
        random_state: seed forwarded to ``KMeans`` to make the split
            reproducible. ``None`` keeps scikit-learn's default behaviour.

    Returns:
        The threshold separating small (intra-burst) from large (inter-burst)
        ISIs.
    """
    samples = np.asarray(isis, dtype=float).reshape(-1, 1)
    kmeans = KMeans(n_clusters=2, n_init=n_init, random_state=random_state).fit(samples)
    # With two clusters, the mean of the centres is the midpoint between them.
    return min(max_isis, kmeans.cluster_centers_.mean(axis=0)[0])


def burst_runaway(
    max_isis: float = 50.0, raise_warnings: bool = False, max_runaway: float = 10.0
) -> np.ndarray:
    """Measure of runaway bursting, based on differences between first and last AHP depths.

    The ISIs inside the stimulus window are split into a small (intra-burst) and
    a large (inter-burst) group with ``_get_burst_thresh``. For each large ISI,
    the voltage minimum between its two peaks (the AHP trough) and its time are
    recorded, stopping before trailing tonic spikes. The feature is the slope
    between the second and the second-to-last troughs, multiplied by 1000 and
    capped at ``max_runaway``. A small value, usually less than 0.2, corresponds
    to a runaway burst.

    ``max_runaway`` is also returned whenever the trace is not considered to be
    bursting:
    - fewer than two ISIs in the stimulus window;
    - all ISIs below ``max_isis``;
    - all large ISIs above 2000;
    - too many ISIs close to the threshold;
    - fewer than four recorded troughs.

    Args:
        max_isis: if every ISI is below this value, the trace is treated as a
            single burst; also caps the ISI threshold.
        raise_warnings: if True, emit a RuntimeWarning when too many ISIs lie
            close to the threshold.
        max_runaway: upper bound of the result, and the value returned when no
            bursting is detected.

    Returns:
        A one-element array.
    """
    stim_start = _get_cpp_data("stim_start")
    stim_end = _get_cpp_data("stim_end")
    peak_times = get_cpp_feature("peak_time")

    # if we have no or one spike, there cannot be a burst
    if peak_times is None or len(peak_times) == 1:
        return np.array([max_runaway])

    peak_times = peak_times[(peak_times > stim_start) & (peak_times < stim_end)]
    isis = np.diff(peak_times)

    # with fewer than two ISIs inside the stimulus, there cannot be a burst
    if len(isis) < 2:
        return np.array([max_runaway])

    # if all isis are less than max_isis, we assume it is a single burst
    if np.max(isis) < max_isis:
        return np.array([max_runaway])
    thresh = _get_burst_thresh(isis, max_isis=max_isis)

    # if more than 10% of isis in the left group are larger than max_isis, it is
    # not bursting
    # NOTE(review): thresh is capped at max_isis, so small_isis (< thresh) can
    # never exceed max_isis and this check never fires — confirm intent.
    small_isis = isis[isis < thresh]
    if np.count_nonzero(small_isis > max_isis) > 0.1 * len(isis):
        return np.array([max_runaway])

    # if the smallest of the right group is too large, it is not bursting
    large_isis = isis[isis > thresh]
    if len(large_isis) and np.min(large_isis) > 2000:
        return np.array([max_runaway])

    # here we check if the gap between the two groups of ISIs is big enough
    # to be considered a burst behaviour; the 1.2 and 0.8 are fairly arbitrary
    near_thresh = (isis < 1.2 * thresh) & (isis > 0.8 * thresh)
    if np.count_nonzero(near_thresh) > 0.1 * len(isis):
        if raise_warnings:
            warnings.warn(
                "While calculating burst_runaway, there are spikes around the "
                f"threshold, we return max_runaway ({max_runaway})",
                RuntimeWarning,
            )
        return np.array([max_runaway])

    voltage = get_cpp_feature("voltage")
    time = get_cpp_feature("time")

    # A large ISI is counted as tonic when its voltage trough is reached within
    # `time_shift` of the first peak (the minimum over the whole interval is
    # lower than the minimum once `time_shift` has elapsed), and the previous ISI
    # is also large.
    # NOTE(review): for i == 0, isis[i - 1] wraps around to the last ISI —
    # confirm this is intended.
    time_shift = 5
    n_tonic = 0
    for i, isi in enumerate(isis):
        if isi > thresh:
            start, end = peak_times[i], peak_times[i + 1]
            lowest = np.min(voltage[(time > start) & (time < end)])
            slow_lowest = np.min(voltage[(time > start + time_shift) & (time < end)])
            if lowest < slow_lowest and isis[i - 1] > thresh:
                n_tonic += 1

    # AHP troughs (voltage minimum and its time) of the large ISIs.
    voltages = []
    times = []
    for i in np.flatnonzero(isis > thresh):
        if i > len(isis) - n_tonic:
            # stop at the end of the last burst before tonic
            # NOTE(review): if the tonic ISIs are the last n_tonic ones, index
            # len(isis) - n_tonic (the first tonic ISI) is still included here;
            # `>=` may be intended — confirm.
            break
        mask = (time >= peak_times[i]) & (time < peak_times[i + 1])
        segment = voltage[mask]
        k = np.argmin(segment)
        voltages.append(segment[k])
        times.append(time[mask][k])

    if len(voltages) < 4:
        return np.array([max_runaway])

    logger.debug("burst_runaway AHP troughs: voltages=%s, times=%s", voltages, times)
    # we use the second burst ahp, to get rid of any transient behavior in the first burst
    slope = (voltages[-2] - voltages[1]) / (times[-2] - times[1]) * 1000.0
    return np.array([min(max_runaway, slope)])


def all_burst_number(max_isis=50.0):
    """The number of all the bursts, even if they have a single AP.

    Instead of relying on burst_mean_freq, we split the ISIs inside the stimulus
    window into two groups (see ``_get_burst_thresh``) and count the number of
    large ISIs. If there are no two distinct groups, there are no bursts. The
    two groups need not have the same size.

    A large ISI is counted as tonic when its voltage trough is reached within 5
    time units of the first peak and the previous ISI is also large. With ``n``
    tonic ISIs, the result is the number of large ISIs minus ``n``. Without any,
    it is the number of large ISIs plus one. Consequently, if an initial burst is
    followed by regular spiking that is not detected as tonic, this feature
    returns the number of APs minus the number of additional spikes in the
    initial burst.

    Special cases:
    - 0 when fewer than two spikes fall inside the stimulus window;
    - 1 when all ISIs are below ``max_isis``;
    - 0 for a single ISI of at least ``max_isis``;
    - 0 when the smallest large ISI exceeds 4000.

    Args:
        max_isis: threshold below which all ISIs are considered part of a single
            burst; also caps the ISI threshold.

    Returns:
        A one-element integer array with the number of bursts.
    """
    stim_start = _get_cpp_data("stim_start")
    stim_end = _get_cpp_data("stim_end")
    peak_times = get_cpp_feature("peak_time")

    # if we have no or one spike, there cannot be a burst
    if peak_times is None or len(peak_times) == 1:
        return np.array([0])

    peak_times = peak_times[(peak_times > stim_start) & (peak_times < stim_end)]
    isis = np.diff(peak_times)

    # fewer than two spikes inside the stimulus window: no burst
    # (np.max below would raise on an empty array)
    if len(isis) == 0:
        return np.array([0])

    # if all isis are less than max_isis, we assume it is a single burst
    if np.max(isis) < max_isis:
        return np.array([1])

    # a single large isi is not a burst
    if len(isis) < 2:
        return np.array([0])

    thresh = _get_burst_thresh(isis, max_isis=max_isis)
    large_isis = isis[isis > thresh]

    # if the smallest of the right group is too large, it is not bursting
    if len(large_isis) and np.min(large_isis) > 4000:
        return np.array([0])

    voltage = get_cpp_feature("voltage")
    time = get_cpp_feature("time")

    # A large ISI is tonic when the minimum over the whole interval is lower than
    # the minimum once `time_shift` has elapsed after the first peak, and the
    # previous ISI is also large.
    # NOTE(review): for i == 0, isis[i - 1] wraps around to the last ISI —
    # confirm this is intended.
    time_shift = 5
    n_tonic = 0
    for i, isi in enumerate(isis):
        if isi > thresh:
            start, end = peak_times[i], peak_times[i + 1]
            lowest = np.min(voltage[(time > start) & (time < end)])
            slow_lowest = np.min(voltage[(time > start + time_shift) & (time < end)])
            if lowest < slow_lowest and isis[i - 1] > thresh:
                n_tonic += 1

    if n_tonic:
        return np.array([len(large_isis) - n_tonic])
    return np.array([len(large_isis) + 1])


def tonic_after_burst(max_isis: float = 50.0) -> np.ndarray:
    """Compute the number of spikes in the tonic phase after bursting; 0 if no tonic.

    This is based on the same ISI split and tonic detection as
    all_burst_number. A large ISI (above the burst threshold) is counted as
    tonic when the voltage trough between its two peaks is reached within 5 time
    units of the first peak and the previous ISI is also large. With ``n`` such
    ISIs, the tonic phase holds ``n + 1`` spikes.

    Args:
        max_isis: if every ISI is below this value, the trace is treated as a
            single burst (no tonic phase); also caps the ISI threshold.

    Returns:
        A one-element integer array.
    """
    stim_start = _get_cpp_data("stim_start")
    stim_end = _get_cpp_data("stim_end")
    peak_times = get_cpp_feature("peak_time")

    # if we have no or one spike, there cannot be a burst
    if peak_times is None or len(peak_times) == 1:
        return np.array([0])

    peak_times = peak_times[(peak_times > stim_start) & (peak_times < stim_end)]
    isis = np.diff(peak_times)

    # fewer than two spikes inside the stimulus window: nothing to measure
    # (np.max below would raise on an empty array)
    if len(isis) == 0:
        return np.array([0])

    # a single ISI cannot be split into two groups, and if all isis are less
    # than max_isis, we assume it is a single burst: no tonic phase either way
    if len(isis) < 2 or np.max(isis) < max_isis:
        return np.array([0])

    thresh = _get_burst_thresh(isis, max_isis=max_isis)

    voltage = get_cpp_feature("voltage")
    time = get_cpp_feature("time")

    # NOTE(review): for i == 0, isis[i - 1] wraps around to the last ISI —
    # confirm this is intended.
    time_shift = 5
    n_tonic = 0
    for i, isi in enumerate(isis):
        if isi > thresh:
            start, end = peak_times[i], peak_times[i + 1]
            lowest = np.min(voltage[(time > start) & (time < end)])
            slow_lowest = np.min(voltage[(time > start + time_shift) & (time < end)])
            if lowest < slow_lowest and isis[i - 1] > thresh:
                n_tonic += 1

    return np.array([n_tonic + 1 if n_tonic else 0])


def voltage_std():
    """Return the standard deviation of the voltage during the stimulus.

    Only samples whose time lies strictly between ``stim_start`` and
    ``stim_end`` contribute. The result is wrapped in a one-element array.
    """
    v_trace = get_cpp_feature("voltage")
    t_trace = get_cpp_feature("time")
    window = (_get_cpp_data("stim_start"), _get_cpp_data("stim_end"))
    during_stim = (t_trace > window[0]) & (t_trace < window[1])
    return np.array([np.std(v_trace[during_stim])])


# Register the custom feature functions with eFEL, in this order.
for _feature in (all_burst_number, tonic_after_burst, voltage_std, burst_runaway):
    efel.register_feature(_feature)
del _feature
